{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.1.post2'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用PyTorch计算梯度数值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch的Autograd模块实现了深度学习的算法中的向传播求导数，在张量（Tensor类）上的所有操作，Autograd都能为他们自动提供微分，简化了手动计算导数的复杂过程。\n",
    "\n",
    "在0.4以前的版本中，Pytorch使用Variable类来自动计算所有的梯度Variable类主要包含三个属性：\n",
    "data：保存Variable所包含的Tensor；grad：保存data对应的梯度，grad也是个Variable，而不是Tensor，它和data的形状一样；grad_fn：指向一个Function对象，这个Function用来反向传播计算输入的梯度。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从0.4起, Variable 正式合并入Tensor类, 通过Variable嵌套实现的自动微分功能已经整合进入了Tensor类中。虽然为了代码的兼容性还是可以使用Variable(tensor)这种方式进行嵌套, 但是这个操作其实什么都没做。\n",
    "\n",
    "所以，以后的代码建议直接使用Tensor类进行操作，因为官方文档中已经将Variable设置成过期模块。\n",
    "\n",
    "要想通过Tensor类本身就支持了使用autograd功能，只需要设置.requries_grad=True\n",
    "\n",
    "Variable类中的的grad和grad_fn属性已经整合进入了Tensor类中"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在张量创建时，通过设置 requires_grad 标识为Ture来告诉Pytorch需要对该张量进行自动求导，PyTorch会记录该张量的每一步操作历史并自动计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0403, 0.5633, 0.2561, 0.4064, 0.9596],\n",
       "        [0.6928, 0.1832, 0.5380, 0.6386, 0.8710],\n",
       "        [0.5332, 0.8216, 0.8139, 0.1925, 0.4993],\n",
       "        [0.2650, 0.6230, 0.5945, 0.3230, 0.0752],\n",
       "        [0.0919, 0.4770, 0.4622, 0.6185, 0.2761]], requires_grad=True)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5, 5, requires_grad=True)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2269, 0.7673, 0.8179, 0.5558, 0.0493],\n",
       "        [0.7762, 0.9242, 0.2872, 0.0035, 0.4197],\n",
       "        [0.4322, 0.5281, 0.9001, 0.7276, 0.3218],\n",
       "        [0.5123, 0.6567, 0.9465, 0.0475, 0.9172],\n",
       "        [0.9899, 0.9284, 0.5303, 0.1718, 0.3937]], requires_grad=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.rand(5, 5, requires_grad=True)\n",
    "y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch会自动追踪和记录对与张量的所有操作，当计算完成后调用.backward()方法自动计算梯度并且将计算结果保存到grad属性中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(25.6487, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z=torch.sum(x+y)\n",
    "z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在张量进行操作后，grad_fn已经被赋予了一个新的函数，这个函数引用了一个创建了这个Tensor类的Function对象。\n",
    "Tensor和Function互相连接生成了一个非循环图，它记录并且编码了完整的计算历史。每个张量都有一个.grad_fn属性，如果这个张量是用户手动创建的那么这个张量的grad_fn是None。\n",
    "\n",
    "下面我们来调用反向传播函数，计算其梯度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简单的自动求导"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]]) tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "z.backward()\n",
    "print(x.grad,y.grad)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果Tensor类表示的是一个标量（即它包含一个元素的张量），则不需要为backward()指定任何参数，但是如果它有更多的元素，则需要指定一个gradient参数，它是形状匹配的张量。\n",
    "以上的 `z.backward()`相当于是`z.backward(torch.tensor(1.))`的简写。\n",
    "这种参数常出现在图像分类中的单标签分类，输出一个标量代表图像的标签。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 复杂的自动求导"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[3.3891e-01, 4.9468e-01, 8.0797e-02, 2.5656e-01, 2.9529e-01],\n",
       "        [7.1946e-01, 1.6977e-02, 1.7965e-01, 3.2656e-01, 1.7665e-01],\n",
       "        [3.1353e-01, 2.2096e-01, 1.2251e+00, 5.5087e-01, 5.9572e-02],\n",
       "        [1.3015e+00, 3.8029e-01, 1.1103e+00, 4.0392e-01, 2.2055e-01],\n",
       "        [8.8726e-02, 6.9701e-01, 8.0164e-01, 9.7221e-01, 4.2239e-04]],\n",
       "       grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5, 5, requires_grad=True)\n",
    "y = torch.rand(5, 5, requires_grad=True)\n",
    "z= x**2+y**3\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2087, 1.3554, 0.5560, 1.0009, 0.9931],\n",
      "        [1.2655, 0.1223, 0.8008, 1.1127, 0.7261],\n",
      "        [1.1052, 0.2579, 1.8006, 0.1544, 0.3646],\n",
      "        [1.8855, 1.2296, 1.9061, 0.9313, 0.0648],\n",
      "        [0.5952, 1.6190, 0.8430, 1.9213, 0.0322]])\n"
     ]
    }
   ],
   "source": [
    "#我们的返回值不是一个标量，所以需要输入一个大小相同的张量作为参数，这里我们用ones_like函数根据x生成一个张量\n",
    "z.backward(torch.ones_like(x))\n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以使用with torch.no_grad()上下文管理器临时禁止对已设置requires_grad=True的张量进行自动求导。这个方法在测试集计算准确率的时候会经常用到，例如："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    print((x +y*2).requires_grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用.no_grad()进行嵌套后，代码不会跟踪历史记录，也就是说保存的这部分记录会减少内存的使用量并且会加快少许的运算速度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd 过程解析"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了说明Pytorch的自动求导原理，我们来尝试分析一下PyTorch的源代码，虽然Pytorch的 Tensor和 TensorBase都是使用CPP来实现的，但是可以使用一些Python的一些方法查看这些对象在Python的属性和状态。\n",
    " Python的 `dir()` 返回参数的属性、方法列表。`z`是一个Tensor变量，看看里面有哪些成员变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__abs__',\n",
       " '__add__',\n",
       " '__and__',\n",
       " '__array__',\n",
       " '__array_priority__',\n",
       " '__array_wrap__',\n",
       " '__bool__',\n",
       " '__class__',\n",
       " '__deepcopy__',\n",
       " '__delattr__',\n",
       " '__delitem__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__div__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__float__',\n",
       " '__floordiv__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__getitem__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__iadd__',\n",
       " '__iand__',\n",
       " '__idiv__',\n",
       " '__ilshift__',\n",
       " '__imul__',\n",
       " '__index__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__int__',\n",
       " '__invert__',\n",
       " '__ior__',\n",
       " '__ipow__',\n",
       " '__irshift__',\n",
       " '__isub__',\n",
       " '__iter__',\n",
       " '__itruediv__',\n",
       " '__ixor__',\n",
       " '__le__',\n",
       " '__len__',\n",
       " '__long__',\n",
       " '__lshift__',\n",
       " '__lt__',\n",
       " '__matmul__',\n",
       " '__mod__',\n",
       " '__module__',\n",
       " '__mul__',\n",
       " '__ne__',\n",
       " '__neg__',\n",
       " '__new__',\n",
       " '__nonzero__',\n",
       " '__or__',\n",
       " '__pow__',\n",
       " '__radd__',\n",
       " '__rdiv__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__reversed__',\n",
       " '__rfloordiv__',\n",
       " '__rmul__',\n",
       " '__rpow__',\n",
       " '__rshift__',\n",
       " '__rsub__',\n",
       " '__rtruediv__',\n",
       " '__setattr__',\n",
       " '__setitem__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__sub__',\n",
       " '__subclasshook__',\n",
       " '__truediv__',\n",
       " '__weakref__',\n",
       " '__xor__',\n",
       " '_backward_hooks',\n",
       " '_base',\n",
       " '_cdata',\n",
       " '_coalesced_',\n",
       " '_dimI',\n",
       " '_dimV',\n",
       " '_grad',\n",
       " '_grad_fn',\n",
       " '_indices',\n",
       " '_make_subclass',\n",
       " '_nnz',\n",
       " '_values',\n",
       " '_version',\n",
       " 'abs',\n",
       " 'abs_',\n",
       " 'acos',\n",
       " 'acos_',\n",
       " 'add',\n",
       " 'add_',\n",
       " 'addbmm',\n",
       " 'addbmm_',\n",
       " 'addcdiv',\n",
       " 'addcdiv_',\n",
       " 'addcmul',\n",
       " 'addcmul_',\n",
       " 'addmm',\n",
       " 'addmm_',\n",
       " 'addmv',\n",
       " 'addmv_',\n",
       " 'addr',\n",
       " 'addr_',\n",
       " 'all',\n",
       " 'allclose',\n",
       " 'any',\n",
       " 'apply_',\n",
       " 'argmax',\n",
       " 'argmin',\n",
       " 'argsort',\n",
       " 'as_strided',\n",
       " 'as_strided_',\n",
       " 'asin',\n",
       " 'asin_',\n",
       " 'atan',\n",
       " 'atan2',\n",
       " 'atan2_',\n",
       " 'atan_',\n",
       " 'backward',\n",
       " 'baddbmm',\n",
       " 'baddbmm_',\n",
       " 'bernoulli',\n",
       " 'bernoulli_',\n",
       " 'bincount',\n",
       " 'bmm',\n",
       " 'btrifact',\n",
       " 'btrifact_with_info',\n",
       " 'btrisolve',\n",
       " 'byte',\n",
       " 'cauchy_',\n",
       " 'ceil',\n",
       " 'ceil_',\n",
       " 'char',\n",
       " 'cholesky',\n",
       " 'chunk',\n",
       " 'clamp',\n",
       " 'clamp_',\n",
       " 'clamp_max',\n",
       " 'clamp_max_',\n",
       " 'clamp_min',\n",
       " 'clamp_min_',\n",
       " 'clone',\n",
       " 'coalesce',\n",
       " 'contiguous',\n",
       " 'copy_',\n",
       " 'cos',\n",
       " 'cos_',\n",
       " 'cosh',\n",
       " 'cosh_',\n",
       " 'cpu',\n",
       " 'cross',\n",
       " 'cuda',\n",
       " 'cumprod',\n",
       " 'cumsum',\n",
       " 'data',\n",
       " 'data_ptr',\n",
       " 'dense_dim',\n",
       " 'det',\n",
       " 'detach',\n",
       " 'detach_',\n",
       " 'device',\n",
       " 'diag',\n",
       " 'diag_embed',\n",
       " 'diagflat',\n",
       " 'diagonal',\n",
       " 'digamma',\n",
       " 'digamma_',\n",
       " 'dim',\n",
       " 'dist',\n",
       " 'div',\n",
       " 'div_',\n",
       " 'dot',\n",
       " 'double',\n",
       " 'dtype',\n",
       " 'eig',\n",
       " 'element_size',\n",
       " 'eq',\n",
       " 'eq_',\n",
       " 'equal',\n",
       " 'erf',\n",
       " 'erf_',\n",
       " 'erfc',\n",
       " 'erfc_',\n",
       " 'erfinv',\n",
       " 'erfinv_',\n",
       " 'exp',\n",
       " 'exp_',\n",
       " 'expand',\n",
       " 'expand_as',\n",
       " 'expm1',\n",
       " 'expm1_',\n",
       " 'exponential_',\n",
       " 'fft',\n",
       " 'fill_',\n",
       " 'flatten',\n",
       " 'flip',\n",
       " 'float',\n",
       " 'floor',\n",
       " 'floor_',\n",
       " 'fmod',\n",
       " 'fmod_',\n",
       " 'frac',\n",
       " 'frac_',\n",
       " 'gather',\n",
       " 'ge',\n",
       " 'ge_',\n",
       " 'gels',\n",
       " 'geometric_',\n",
       " 'geqrf',\n",
       " 'ger',\n",
       " 'gesv',\n",
       " 'get_device',\n",
       " 'grad',\n",
       " 'grad_fn',\n",
       " 'gt',\n",
       " 'gt_',\n",
       " 'half',\n",
       " 'hardshrink',\n",
       " 'histc',\n",
       " 'ifft',\n",
       " 'index_add',\n",
       " 'index_add_',\n",
       " 'index_copy',\n",
       " 'index_copy_',\n",
       " 'index_fill',\n",
       " 'index_fill_',\n",
       " 'index_put',\n",
       " 'index_put_',\n",
       " 'index_select',\n",
       " 'indices',\n",
       " 'int',\n",
       " 'inverse',\n",
       " 'irfft',\n",
       " 'is_coalesced',\n",
       " 'is_complex',\n",
       " 'is_contiguous',\n",
       " 'is_cuda',\n",
       " 'is_distributed',\n",
       " 'is_floating_point',\n",
       " 'is_leaf',\n",
       " 'is_nonzero',\n",
       " 'is_pinned',\n",
       " 'is_same_size',\n",
       " 'is_set_to',\n",
       " 'is_shared',\n",
       " 'is_signed',\n",
       " 'is_sparse',\n",
       " 'isclose',\n",
       " 'item',\n",
       " 'kthvalue',\n",
       " 'layout',\n",
       " 'le',\n",
       " 'le_',\n",
       " 'lerp',\n",
       " 'lerp_',\n",
       " 'lgamma',\n",
       " 'lgamma_',\n",
       " 'log',\n",
       " 'log10',\n",
       " 'log10_',\n",
       " 'log1p',\n",
       " 'log1p_',\n",
       " 'log2',\n",
       " 'log2_',\n",
       " 'log_',\n",
       " 'log_normal_',\n",
       " 'log_softmax',\n",
       " 'logdet',\n",
       " 'logsumexp',\n",
       " 'long',\n",
       " 'lt',\n",
       " 'lt_',\n",
       " 'map2_',\n",
       " 'map_',\n",
       " 'masked_fill',\n",
       " 'masked_fill_',\n",
       " 'masked_scatter',\n",
       " 'masked_scatter_',\n",
       " 'masked_select',\n",
       " 'matmul',\n",
       " 'matrix_power',\n",
       " 'max',\n",
       " 'mean',\n",
       " 'median',\n",
       " 'min',\n",
       " 'mm',\n",
       " 'mode',\n",
       " 'mul',\n",
       " 'mul_',\n",
       " 'multinomial',\n",
       " 'mv',\n",
       " 'mvlgamma',\n",
       " 'mvlgamma_',\n",
       " 'name',\n",
       " 'narrow',\n",
       " 'narrow_copy',\n",
       " 'ndimension',\n",
       " 'ne',\n",
       " 'ne_',\n",
       " 'neg',\n",
       " 'neg_',\n",
       " 'nelement',\n",
       " 'new',\n",
       " 'new_empty',\n",
       " 'new_full',\n",
       " 'new_ones',\n",
       " 'new_tensor',\n",
       " 'new_zeros',\n",
       " 'nonzero',\n",
       " 'norm',\n",
       " 'normal_',\n",
       " 'numel',\n",
       " 'numpy',\n",
       " 'orgqr',\n",
       " 'ormqr',\n",
       " 'output_nr',\n",
       " 'permute',\n",
       " 'pin_memory',\n",
       " 'pinverse',\n",
       " 'polygamma',\n",
       " 'polygamma_',\n",
       " 'potrf',\n",
       " 'potri',\n",
       " 'potrs',\n",
       " 'pow',\n",
       " 'pow_',\n",
       " 'prelu',\n",
       " 'prod',\n",
       " 'pstrf',\n",
       " 'put_',\n",
       " 'qr',\n",
       " 'random_',\n",
       " 'reciprocal',\n",
       " 'reciprocal_',\n",
       " 'record_stream',\n",
       " 'register_hook',\n",
       " 'reinforce',\n",
       " 'relu',\n",
       " 'relu_',\n",
       " 'remainder',\n",
       " 'remainder_',\n",
       " 'renorm',\n",
       " 'renorm_',\n",
       " 'repeat',\n",
       " 'requires_grad',\n",
       " 'requires_grad_',\n",
       " 'reshape',\n",
       " 'reshape_as',\n",
       " 'resize',\n",
       " 'resize_',\n",
       " 'resize_as',\n",
       " 'resize_as_',\n",
       " 'retain_grad',\n",
       " 'rfft',\n",
       " 'roll',\n",
       " 'rot90',\n",
       " 'round',\n",
       " 'round_',\n",
       " 'rsqrt',\n",
       " 'rsqrt_',\n",
       " 'scatter',\n",
       " 'scatter_',\n",
       " 'scatter_add',\n",
       " 'scatter_add_',\n",
       " 'select',\n",
       " 'set_',\n",
       " 'shape',\n",
       " 'share_memory_',\n",
       " 'short',\n",
       " 'sigmoid',\n",
       " 'sigmoid_',\n",
       " 'sign',\n",
       " 'sign_',\n",
       " 'sin',\n",
       " 'sin_',\n",
       " 'sinh',\n",
       " 'sinh_',\n",
       " 'size',\n",
       " 'slogdet',\n",
       " 'smm',\n",
       " 'softmax',\n",
       " 'sort',\n",
       " 'sparse_dim',\n",
       " 'sparse_mask',\n",
       " 'sparse_resize_',\n",
       " 'sparse_resize_and_clear_',\n",
       " 'split',\n",
       " 'split_with_sizes',\n",
       " 'sqrt',\n",
       " 'sqrt_',\n",
       " 'squeeze',\n",
       " 'squeeze_',\n",
       " 'sspaddmm',\n",
       " 'std',\n",
       " 'stft',\n",
       " 'storage',\n",
       " 'storage_offset',\n",
       " 'storage_type',\n",
       " 'stride',\n",
       " 'sub',\n",
       " 'sub_',\n",
       " 'sum',\n",
       " 'svd',\n",
       " 'symeig',\n",
       " 't',\n",
       " 't_',\n",
       " 'take',\n",
       " 'tan',\n",
       " 'tan_',\n",
       " 'tanh',\n",
       " 'tanh_',\n",
       " 'to',\n",
       " 'to_dense',\n",
       " 'to_sparse',\n",
       " 'tolist',\n",
       " 'topk',\n",
       " 'trace',\n",
       " 'transpose',\n",
       " 'transpose_',\n",
       " 'tril',\n",
       " 'tril_',\n",
       " 'triu',\n",
       " 'triu_',\n",
       " 'trtrs',\n",
       " 'trunc',\n",
       " 'trunc_',\n",
       " 'type',\n",
       " 'type_as',\n",
       " 'unbind',\n",
       " 'unfold',\n",
       " 'uniform_',\n",
       " 'unique',\n",
       " 'unsqueeze',\n",
       " 'unsqueeze_',\n",
       " 'values',\n",
       " 'var',\n",
       " 'view',\n",
       " 'view_as',\n",
       " 'where',\n",
       " 'zero_']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "返回很多，我们直接排除掉一些Python中特殊方法（以__开头和结束的）和私有方法（以_开头的，直接看几个比较主要的属性：\n",
    "`.is_leaf`：记录是否是叶子节点。通过这个属性来确定这个变量的类型\n",
    "在官方文档中所说的“graph leaves”,“leaf variables”，都是指像`x`,`y`这样的手动创建的、而非运算得到的变量，这些变量成为创建变量。\n",
    "像`z`这样的，是通过计算后得到的结果称为结果变量。\n",
    "\n",
    "一个变量是创建变量还是结果变量是通过`.is_leaf`来获取的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.is_leaf=True\n",
      "z.is_leaf=False\n"
     ]
    }
   ],
   "source": [
    "print(\"x.is_leaf=\"+str(x.is_leaf))\n",
    "print(\"z.is_leaf=\"+str(z.is_leaf))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`x`是手动创建的没有通过计算，所以他被认为是一个叶子节点也就是一个创建变量，而`z`是通过`x`与`y`的一系列计算得到的，所以不是叶子结点也就是结果变量。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为什么我们执行`z.backward()`方法会更新`x.grad`和`y.grad`呢？\n",
    "`.grad_fn`属性记录的就是这部分的操作，虽然`.backward()`方法也是CPP实现的，但是可以通过Python来进行简单的探索。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`grad_fn`：记录并且编码了完整的计算历史"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AddBackward0 at 0x120840a90>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z.grad_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`grad_fn`是一个`AddBackward0`类型的变量 `AddBackward0`这个类也是用Cpp来写的,但是我们从名字里就能够大概知道，他是加法(ADD)的反反向传播（Backward），看看里面有些什么东西"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__call__',\n",
       " '__class__',\n",
       " '__delattr__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '_register_hook_dict',\n",
       " 'metadata',\n",
       " 'name',\n",
       " 'next_functions',\n",
       " 'register_hook',\n",
       " 'requires_grad']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(z.grad_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`next_functions`就是`grad_fn`的精华"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((<PowBackward0 at 0x1208409b0>, 0), (<PowBackward0 at 0x1208408d0>, 0))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z.grad_fn.next_functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`next_functions`是一个tuple of tuple of PowBackward0 and int。\n",
    "\n",
    "为什么是2个tuple ？\n",
    "因为我们的操作是`z= x**2+y**3` 刚才的`AddBackward0`是相加，而前面的操作是乘方 `PowBackward0`。tuple第一个元素就是x相关的操作记录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__call__',\n",
       " '__class__',\n",
       " '__delattr__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '_register_hook_dict',\n",
       " 'metadata',\n",
       " 'name',\n",
       " 'next_functions',\n",
       " 'register_hook',\n",
       " 'requires_grad']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xg = z.grad_fn.next_functions[0][0]\n",
    "dir(xg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "继续深挖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AccumulateGrad"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_leaf=xg.next_functions[0][0]\n",
    "type(x_leaf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在PyTorch的反向图计算中，`AccumulateGrad`类型代表的就是叶子节点类型，也就是计算图终止节点。`AccumulateGrad`类中有一个`.variable`属性指向叶子节点。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1044, 0.6777, 0.2780, 0.5005, 0.4966],\n",
       "        [0.6328, 0.0611, 0.4004, 0.5564, 0.3631],\n",
       "        [0.5526, 0.1290, 0.9003, 0.0772, 0.1823],\n",
       "        [0.9428, 0.6148, 0.9530, 0.4657, 0.0324],\n",
       "        [0.2976, 0.8095, 0.4215, 0.9606, 0.0161]], requires_grad=True)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_leaf.variable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个`.variable`的属性就是我们的生成的变量`x`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_leaf.variable的id:4840553424\n",
      "x的id:4840553424\n"
     ]
    }
   ],
   "source": [
    "print(\"x_leaf.variable的id:\"+str(id(x_leaf.variable)))\n",
    "print(\"x的id:\"+str(id(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert(id(x_leaf.variable)==id(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这样整个规程就很清晰了：\n",
    "\n",
    "1. 当我们执行z.backward()的时候。这个操作将调用z里面的grad_fn这个属性，执行求导的操作。\n",
    "2. 这个操作将遍历grad_fn的next_functions，然后分别取出里面的Function（AccumulateGrad），执行求导操作。这部分是一个递归的过程直到最后类型为叶子节点。\n",
    "3. 计算出结果以后，将结果保存到他们对应的variable 这个变量所引用的对象（x和y）的 grad这个属性里面。\n",
    "4. 求导结束。所有的叶节点的grad变量都得到了相应的更新\n",
    "\n",
    "最终当我们执行完c.backward()之后，a和b里面的grad值就得到了更新。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 扩展Autograd\n",
    "如果需要自定义autograd扩展新的功能，就需要扩展Function类。因为Function使用autograd来计算结果和梯度，并对操作历史进行编码。\n",
    "在Function类中最主要的方法就是`forward()`和`backward()`他们分别代表了前向传播和反向传播。\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一个自定义的Function需要一下三个方法：\n",
    "\n",
    "    __init__ (optional)：如果这个操作需要额外的参数则需要定义这个Function的构造函数，不需要的话可以忽略。\n",
    "    \n",
    "    forward()：执行前向传播的计算代码\n",
    "    \n",
    "    backward()：反向传播时梯度计算的代码。 参数的个数和forward返回值的个数一样，每个参数代表传回到此操作的梯度。\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 引入Function便于扩展\n",
    "from torch.autograd.function import Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义一个乘以常数的操作(输入参数是张量)\n",
    "# 方法必须是静态方法，所以要加上@staticmethod \n",
    "class MulConstant(Function):\n",
    "    @staticmethod \n",
    "    def forward(ctx, tensor, constant):\n",
    "        # ctx 用来保存信息这里类似self，并且ctx的属性可以在backward中调用\n",
    "        ctx.constant=constant\n",
    "        return tensor *constant\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        # 返回的参数要与输入的参数一样.\n",
    "        # 第一个输入为3x3的张量，第二个为一个常数\n",
    "        # 常数的梯度必须是 None.\n",
    "        return grad_output, None "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义完我们的新操作后，我们来进行测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a:tensor([[0.0118, 0.1434, 0.8669],\n",
      "        [0.1817, 0.8904, 0.5852],\n",
      "        [0.7364, 0.5234, 0.9677]], requires_grad=True)\n",
      "b:tensor([[0.0588, 0.7169, 4.3347],\n",
      "        [0.9084, 4.4520, 2.9259],\n",
      "        [3.6820, 2.6171, 4.8386]], grad_fn=<MulConstantBackward>)\n"
     ]
    }
   ],
   "source": [
    "a=torch.rand(3,3,requires_grad=True)\n",
    "b=MulConstant.apply(a,5)\n",
    "print(\"a:\"+str(a))\n",
    "print(\"b:\"+str(b)) # b为a的元素乘以5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "反向传播，返回值不是标量，所以`backward`方法需要参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "b.backward(torch.ones_like(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1.],\n",
       "        [1., 1., 1.],\n",
       "        [1., 1., 1.]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "梯度因为1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ML with tensorflow and Pytorch",
   "language": "python",
   "name": "ml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "288px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
